{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# How to use few-shot prompting with tool calling\n",
    "\n",
    "For more complex tool use it's very useful to add few-shot examples to the prompt. We can do this by adding `AIMessage`s with `ToolCall`s and corresponding `ToolMessage`s to our prompt.\n",
    "\n",
    "First let's define our tools and model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.tools import tool\n",
    "\n",
    "\n",
    "@tool\n",
    "def add(a: int, b: int) -> int:\n",
    "    \"\"\"Adds a and b.\"\"\"\n",
    "    return a + b\n",
    "\n",
    "\n",
    "@tool\n",
    "def multiply(a: int, b: int) -> int:\n",
    "    \"\"\"Multiplies a and b.\"\"\"\n",
    "    return a * b\n",
    "\n",
    "\n",
    "tools = [add, multiply]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from getpass import getpass\n",
    "\n",
    "from langchain_openai import ChatOpenAI\n",
    "\n",
    "os.environ[\"OPENAI_API_KEY\"] = getpass()\n",
    "\n",
    "llm = ChatOpenAI(model=\"gpt-3.5-turbo-0125\", temperature=0)\n",
    "llm_with_tools = llm.bind_tools(tools)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's run our model where we can notice that even with some special instructions our model can get tripped up by order of operations. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'name': 'Multiply',\n",
       "  'args': {'a': 119, 'b': 8},\n",
       "  'id': 'call_T88XN6ECucTgbXXkyDeC2CQj'},\n",
       " {'name': 'Add',\n",
       "  'args': {'a': 952, 'b': -20},\n",
       "  'id': 'call_licdlmGsRqzup8rhqJSb1yZ4'}]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "llm_with_tools.invoke(\n",
    "    \"Whats 119 times 8 minus 20. Don't do any math yourself, only use tools for math. Respect order of operations\"\n",
    ").tool_calls"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model shouldn't be trying to add anything yet, since it technically can't know the results of 119 * 8 yet.\n",
    "\n",
    "By adding a prompt with some examples we can correct this behavior:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'name': 'Multiply',\n",
       "  'args': {'a': 119, 'b': 8},\n",
       "  'id': 'call_9MvuwQqg7dlJupJcoTWiEsDo'}]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from langchain_core.messages import AIMessage, HumanMessage, ToolMessage\n",
    "from langchain_core.prompts import ChatPromptTemplate\n",
    "from langchain_core.runnables import RunnablePassthrough\n",
    "\n",
    "examples = [\n",
    "    HumanMessage(\n",
    "        \"What's the product of 317253 and 128472 plus four\", name=\"example_user\"\n",
    "    ),\n",
    "    AIMessage(\n",
    "        \"\",\n",
    "        name=\"example_assistant\",\n",
    "        tool_calls=[\n",
    "            {\"name\": \"Multiply\", \"args\": {\"x\": 317253, \"y\": 128472}, \"id\": \"1\"}\n",
    "        ],\n",
    "    ),\n",
    "    ToolMessage(\"16505054784\", tool_call_id=\"1\"),\n",
    "    AIMessage(\n",
    "        \"\",\n",
    "        name=\"example_assistant\",\n",
    "        tool_calls=[{\"name\": \"Add\", \"args\": {\"x\": 16505054784, \"y\": 4}, \"id\": \"2\"}],\n",
    "    ),\n",
    "    ToolMessage(\"16505054788\", tool_call_id=\"2\"),\n",
    "    AIMessage(\n",
    "        \"The product of 317253 and 128472 plus four is 16505054788\",\n",
    "        name=\"example_assistant\",\n",
    "    ),\n",
    "]\n",
    "\n",
    "system = \"\"\"You are bad at math but are an expert at using a calculator. \n",
    "\n",
    "Use past tool usage as an example of how to correctly use the tools.\"\"\"\n",
    "few_shot_prompt = ChatPromptTemplate.from_messages(\n",
    "    [\n",
    "        (\"system\", system),\n",
    "        *examples,\n",
    "        (\"human\", \"{query}\"),\n",
    "    ]\n",
    ")\n",
    "\n",
    "chain = {\"query\": RunnablePassthrough()} | few_shot_prompt | llm_with_tools\n",
    "chain.invoke(\"Whats 119 times 8 minus 20\").tool_calls"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And we get the correct output this time.\n",
    "\n",
    "Here's what the [LangSmith trace](https://smith.langchain.com/public/f70550a1-585f-4c9d-a643-13148ab1616f/r) looks like."
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
